"""
DBInfo 定义了数据库环境,当前版本只考虑MySQL数据库
       基于pymysql从数据库中获取表、字段、索引信息

作者：修炼者 7457222@qq.com
日期：2024-12-10 10:48:19
"""
import urllib.parse
import pymysql
from public.jygt_coder_loger import get_logger

# Module-level logger obtained from the project's logging helper (public.jygt_coder_loger).
logger = get_logger(__name__)


class DBInfo:
    """
    MySQL database configuration plus the schema metadata loaded from it.

    Holds the connection parameters of one MySQL database. After ``loadInfo``
    has run, ``info`` holds per-table column and index descriptions. Their
    values are pre-rendered code fragments (e.g. ``'String(64)'``,
    ``'graphene.Int'``, ``'nullable=False'``), presumably substituted into
    generated SQLAlchemy / graphene source by a template — confirm with callers.
    """

    # Positions in a ``SHOW FULL COLUMNS`` row:
    # Field, Type, Collation, Null, Key, Default, Extra, Privileges, Comment
    _COL_FIELD = 0
    _COL_TYPE = 1
    _COL_NULL = 3
    _COL_KEY = 4
    _COL_DEFAULT = 5
    _COL_COMMENT = 8

    # Positions in a ``SHOW INDEX`` row:
    # Table, Non_unique, Key_name, Seq_in_index, Column_name, ...
    _IDX_NON_UNIQUE = 1
    _IDX_KEY_NAME = 2
    _IDX_COLUMN_NAME = 4

    def __init__(self, host, port, user, password, database):
        """
        Initialise the MySQL configuration.

        Parameters:
            - host: database server host
            - port: database server port (int, or a numeric string)
            - user: login user name
            - password: login password
            - database: name of the database (schema) to inspect
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database

        # Table/column/index metadata filled in by loadInfo().
        self.info = []
        # Column name -> graphene type for the table currently being loaded.
        # Written by __getTableColumns, read by __getTableIndexes.
        self.__column_graphene_info = {}

    def getSQLAlchemyConnectionString(self):
        """
        Return a ``mysql+pymysql://`` SQLAlchemy URL for this database.

        The user name and password are URL-encoded, so characters such as
        ``@``, ``:`` or ``/`` cannot corrupt the URL. A ``None`` password is
        rendered as an empty password.
        """
        user = urllib.parse.quote_plus(str(self.user))
        password = urllib.parse.quote_plus('' if self.password is None else str(self.password))
        return f"mysql+pymysql://{user}:{password}@{self.host}:{self.port}/{self.database}"

    def getHost(self):
        return self.host

    def getUser(self):
        return self.user

    def getPassword(self):
        return self.password

    def getDatabase(self):
        return self.database

    def getPort(self):
        return self.port

    def __getLength(self, mysql_type):
        """Return the text between the first '(' and ')' of a MySQL type, or None."""
        if '(' in mysql_type:
            return mysql_type[mysql_type.index('(') + 1: mysql_type.index(')')]
        return None

    def __getNumericType(self, mysql_type):
        """
        Render ``decimal(P,S)`` as ``'Numeric(P, S)'``.

        A bare ``decimal`` yields ``'Numeric'``. A missing scale is rendered
        as 0, which is MySQL's default scale.
        """
        length = self.__getLength(mysql_type)
        if not length:
            return 'Numeric'
        precision, _, scale = length.partition(',')
        return 'Numeric({}, {})'.format(precision.strip(), scale.strip() or 0)

    def __getSqlalchemyType(self, mysql_type):
        """Map a MySQL column type string to a SQLAlchemy type expression string."""
        if mysql_type.startswith(('int', 'tinyint', 'smallint', 'mediumint', 'bigint')):
            # NOTE(review): BigInteger would be more precise for bigint DDL; Integer
            # is used because the generated code's imports are not visible here.
            return 'Integer'
        if mysql_type.startswith('varchar'):
            return 'String({})'.format(self.__getLength(mysql_type) or 255)  # default length 255
        if mysql_type.startswith('char'):
            return 'String({})'.format(self.__getLength(mysql_type) or 1)  # default length 1
        if mysql_type.startswith(('text', 'tinytext', 'mediumtext', 'longtext')):
            return 'Text'
        if mysql_type.startswith(('blob', 'tinyblob', 'mediumblob', 'longblob')):
            return 'LargeBinary'
        if mysql_type.startswith(('datetime', 'timestamp')):
            return 'DateTime'
        if mysql_type.startswith('date'):  # must come after 'datetime'
            return 'Date'
        if mysql_type.startswith('time'):  # must come after 'timestamp'
            return 'Time'
        if mysql_type.startswith(('enum', 'set')):
            return 'Enum'  # Enum needs special handling; SET is treated as Enum too
        if mysql_type.startswith('decimal'):
            return self.__getNumericType(mysql_type)
        if mysql_type.startswith(('float', 'double')):
            return 'Float'  # SQLAlchemy maps both float and double to Float
        if mysql_type.startswith('bit'):
            return 'Boolean'
        return 'String(255)'  # fallback type

    def __getGqlType(self, mysql_type):
        """Map a MySQL column type string to a graphene scalar expression string."""
        # bigint is deliberately left out: GraphQL Int is a signed 32-bit value,
        # so bigint falls through to graphene.String.
        if mysql_type.startswith(('int', 'tinyint', 'smallint', 'mediumint')):
            return 'graphene.Int'
        if mysql_type.startswith('decimal'):
            return 'graphene.Decimal'
        if mysql_type.startswith(('float', 'double')):
            return 'graphene.Float'
        if mysql_type.startswith('boolean'):
            return 'graphene.Boolean'
        # varchar/char/text/blob/date/time/timestamp and everything else
        return 'graphene.String'

    def __quoteIdentifier(self, name):
        """Quote a MySQL identifier with backticks, doubling any embedded backtick."""
        return '`{}`'.format(str(name).replace('`', '``'))

    def __renderDefault(self, default):
        """
        Render a column default as a ``default=...`` keyword fragment.

        ``None`` yields ``'default=None'``. ``CURRENT_TIMESTAMP`` in any
        spelling (``CURRENT_TIMESTAMP(3)``, ``current_timestamp()``) yields
        ``'default=func.now'``. Any other value becomes a quoted Python
        string literal.
        """
        if default is None:
            return 'default=None'
        text = str(default)
        if text.upper().startswith('CURRENT_TIMESTAMP'):
            return 'default=func.now'
        # repr() yields a valid Python literal even when the value contains quotes/newlines.
        return 'default={!r}'.format(text)

    def __buildColumnInfo(self, column):
        """Convert one ``SHOW FULL COLUMNS`` row (a tuple) into a column-description dict."""
        mysql_type = column[self._COL_TYPE]
        comment = column[self._COL_COMMENT]
        return {
            'name': column[self._COL_FIELD],
            'type': self.__getSqlalchemyType(mysql_type),
            'gql_type': self.__getGqlType(mysql_type),
            'gql_required': 'required=True' if column[self._COL_NULL] == 'NO' else 'required=False',
            'nullable': 'nullable=True' if column[self._COL_NULL] == 'YES' else 'nullable=False',
            'default': self.__renderDefault(column[self._COL_DEFAULT]),
            'comment': repr('' if comment is None else str(comment)),
            'primary_key': 'primary_key=True,' if column[self._COL_KEY] == 'PRI' else '',
        }

    def __getTableColumns(self, connection, table_name):
        """
        Read the columns of ``table_name`` and return a list of column-description dicts.

        Side effect: rebuilds ``self.__column_graphene_info`` for this table,
        which ``__getTableIndexes`` reads afterwards. Errors are logged and re-raised.
        """
        logger.info(f"get table fields for {table_name}")
        try:
            with connection.cursor() as cursor:
                cursor.execute("SHOW FULL COLUMNS FROM {}".format(self.__quoteIdentifier(table_name)))
                columns = cursor.fetchall()

            self.__column_graphene_info = {}
            column_info = []
            for column in columns:
                column_info.append(self.__buildColumnInfo(column))
                self.__column_graphene_info[column[self._COL_FIELD]] = self.__getGqlType(column[self._COL_TYPE])
            return column_info
        except Exception as e:
            logger.error(f"Error occurred while fetching table columns: {e}")
            raise

    def __buildIndexInfo(self, indexes):
        """
        Group ``SHOW INDEX`` rows (tuples) into ``{index_name: {'columns', 'unique', 'gql_columns'}}``.

        Requires ``self.__column_graphene_info`` to already hold the table's columns.
        """
        index_info = {}
        for index in indexes:
            column_name = index[self._IDX_COLUMN_NAME]
            if column_name is None:
                # Functional key parts report a NULL Column_name; they have no
                # column to turn into a query parameter, so they are skipped.
                continue
            index_name = index[self._IDX_KEY_NAME]
            if index_name not in index_info:
                index_info[index_name] = {
                    'columns': [],
                    'unique': int(index[self._IDX_NON_UNIQUE]) == 0,
                    'gql_columns': [],
                }
            entry = index_info[index_name]
            entry['columns'].append(column_name)
            entry['gql_columns'].append(f"para_{column_name}={self.__column_graphene_info[column_name]}")
        return index_info

    def __getTableIndexes(self, connection, table_name):
        """
        Read the indexes of ``table_name`` and return them grouped by index name.

        Must run after ``__getTableColumns`` for the same table. Errors are
        logged and re-raised.
        """
        logger.info(f"get table indexes for {table_name}")
        try:
            with connection.cursor() as cursor:
                cursor.execute("SHOW INDEX FROM {}".format(self.__quoteIdentifier(table_name)))
                indexes = cursor.fetchall()
            return self.__buildIndexInfo(indexes)
        except Exception as e:
            logger.error(f"Error occurred while fetching table indexes: {e}")
            raise

    def getInfo(self):
        """Return the metadata list built by ``loadInfo`` (empty before it runs)."""
        return self.info

    def dumpInfo(self):
        """Print the loaded metadata to stdout."""
        print(self.info)

    def loadInfo(self):
        """
        Load table, column and index metadata from the database into ``self.info``.

        Only base tables are included; views are excluded. Tables are ordered
        by name. The resulting ``self.info`` looks like::

            [
                {
                    "table_name": "<table>",
                    "columns": [   # one dict per column
                        {"name": ..., "type": ..., "gql_type": ..., "gql_required": ...,
                         "nullable": ..., "default": ..., "comment": ..., "primary_key": ...},
                    ],
                    "indexes": {   # index name -> description
                        "<index>": {"columns": [...], "unique": bool, "gql_columns": [...]},
                    },
                },
            ]

        Returns None; read the result with ``getInfo()``. Any connection or
        query error is logged and re-raised. The connection is always closed.
        """
        self.info.clear()
        try:
            connection = pymysql.connect(host=self.host,
                                         port=int(self.port) if self.port else 3306,
                                         user=self.user,
                                         password=self.password,
                                         database=self.database,
                                         charset='utf8mb4')
            try:
                logger.info(f"load table info of {self.getDatabase()}")
                with connection.cursor() as cursor:
                    # SHOW TABLES would also list views, so query information_schema instead.
                    cursor.execute(
                        "SELECT table_name FROM information_schema.tables "
                        "WHERE table_schema = %s AND table_type = 'BASE TABLE' "
                        "ORDER BY table_name",
                        (self.database,),
                    )
                    tables = cursor.fetchall()

                for (table_name,) in tables:
                    logger.info(f"get table info for {table_name}")
                    self.info.append({
                        "table_name": table_name,
                        "columns": self.__getTableColumns(connection, table_name),
                        "indexes": self.__getTableIndexes(connection, table_name),
                    })
            finally:
                connection.close()
        except Exception as e:
            logger.error(f"Error loading table info: {e}")
            raise

    def __repr__(self):
        return f"<DBInfo {self.host}:{self.port}/{self.database}>"

